// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//
#include <boost/algorithm/string/join.hpp>
#include <limits>
#include <pollux/expression/expr.h>
#include <pollux/expression/vector_function.h>
#include <pollux/functions/lib/lambda_function_util.h>

namespace kumo::pollux::functions {
    namespace {
        class ZipFunction : public exec::VectorFunction {
            static const auto kMinArity = 2;
            static const auto kMaxArity = 7;

        public:
            /// This class implements the zip function.
            ///
            /// DEFINITION:
            /// zip(ARRAY[T], ARRAY[U]) -> ARRAY(ROW[T,U])
            /// where we create a ROW[Ti, Ui] for every ith element in ARRAY[T], ARRAY[U].
            /// The smaller array is padded with nulls.
            ///
            /// IMPLEMENTATION:
            ///  1. The general idea is to create a new dictionary vector for each input
            ///  array vector and enumerate their indices to create a 1:1 mapping.
            ///  2. To do this, for each row we determine which is the largest Array
            ///  and subsequently pad the smaller arrays with nulls.
            ///  3. Then we take the resultant padded vectors together and create one ROW
            ///  Vector.
            ///  4. This forms the base to create the final output Array vector, whose
            ///  Arrays are the size of the largest input Array.
            ///
            ///  Note:
            ///   - We make no copy's of any constituent elements and are agnostic to
            ///   types.
            ///   - For compatibility with Presto a maximum arity of 7 is enforced.

            void apply(
                const SelectivityVector &rows,
                std::vector<VectorPtr> &args,
                const TypePtr &outputType,
                exec::EvalCtx &context,
                VectorPtr &result) const override {
                const vector_size_t numInputArrays = args.size();

                exec::DecodedArgs decodedArgs(rows, args, context);
                std::vector<const ArrayVector *> baseVectors(numInputArrays);
                std::vector<const vector_size_t *> rawSizes(numInputArrays);
                std::vector<const vector_size_t *> rawOffsets(numInputArrays);
                std::vector<const vector_size_t *> indices(numInputArrays);

                for (int i = 0; i < numInputArrays; i++) {
                    baseVectors[i] = decodedArgs.at(i)->base()->as<ArrayVector>();
                    rawSizes[i] = baseVectors[i]->rawSizes();
                    rawOffsets[i] = baseVectors[i]->rawOffsets();
                    indices[i] = decodedArgs.at(i)->indices();
                }

                // Size of elements in result vector.
                vector_size_t resultElementsSize = 0;
                auto *pool = context.pool();
                // This is true if for all rows, all the arrays within a row are the same
                // size.
                bool allSameSize = true;
                // This is true if for all rows, all the arrays within a row have the same
                // starting offset in their elements Vector.
                bool allSameOffsets = true;

                // Determine what the size of the resultant elements will be so we can
                // reserve enough space.
                auto getMaxArraySize = [&](vector_size_t row) -> vector_size_t {
                    vector_size_t maxSize = 0;
                    vector_size_t offset = -1;
                    for (int i = 0; i < numInputArrays; i++) {
                        vector_size_t size = rawSizes[i][indices[i][row]];
                        allSameSize &= i == 0 || maxSize == size;
                        allSameOffsets &= i == 0 || offset == rawOffsets[i][indices[i][row]];
                        maxSize = std::max(maxSize, size);
                        offset = rawOffsets[i][indices[i][row]];
                    }
                    return maxSize;
                };

                BufferPtr resultArraySizesBuffer = allocateSizes(rows.end(), pool);
                auto rawResultArraySizes =
                        resultArraySizesBuffer->asMutable<vector_size_t>();
                rows.applyToSelected([&](auto row) {
                    auto maxSize = getMaxArraySize(row);
                    resultElementsSize += maxSize;
                    rawResultArraySizes[row] = maxSize;
                });

                if (allSameSize && allSameOffsets) {
                    // This is true if all input vectors have the "flat" Array encoding.
                    bool allFlat = true;
                    for (const auto &arg: args) {
                        allFlat &= arg->encoding() == VectorEncoding::Simple::ARRAY;
                    }

                    if (allFlat) {
                        // Fast path if all input Vectors are flat and for all rows, all arrays
                        // within a row are the same size and start at the same offset.  In this
                        // case we don't have to add nulls, or decode the arrays, we can just
                        // pass in the element Vectors as is to be the fields of the output
                        // Rows.
                        std::vector<VectorPtr> elements;
                        elements.reserve(args.size());
                        // Since the offsets and sizes are all the same, using the minimum size
                        // is big enough to contain all elements, while also guaranteeing all
                        // child Vectors in the RowVector are at least this big.
                        vector_size_t minElementsSize =
                                std::numeric_limits<vector_size_t>::max();
                        for (const auto &arg: args) {
                            elements.push_back(arg->as<ArrayVector>()->elements());
                            minElementsSize = std::min(minElementsSize, elements.back()->size());
                        }

                        auto rowType = outputType->childAt(0);
                        auto row_vector = std::make_shared<RowVector>(
                            pool,
                            rowType,
                            BufferPtr(nullptr),
                            minElementsSize,
                            std::move(elements));

                        // Now convert these to an Array
                        auto array_vector = std::make_shared<ArrayVector>(
                            pool,
                            outputType,
                            BufferPtr(nullptr),
                            rows.end(),
                            baseVectors[0]->offsets(),
                            resultArraySizesBuffer,
                            std::move(row_vector));

                        context.moveOrCopyResult(array_vector, rows, result);

                        return;
                    }
                }

                // Create individual result vectors for each input Array vector.

                std::vector<BufferPtr> nestedResultIndices(numInputArrays);
                std::vector<BufferPtr> nestedResultNulls(numInputArrays);
                std::vector<vector_size_t *> rawNestedResultIndices(numInputArrays);
                std::vector<uint64_t *> rawNestedResultNulls(numInputArrays);

                for (int i = 0; i < numInputArrays; i++) {
                    nestedResultIndices[i] = allocate_indices(resultElementsSize, pool);
                    nestedResultNulls[i] = AlignedBuffer::allocate<bool>(
                        resultElementsSize, pool, bits::kNotNull);
                    rawNestedResultIndices[i] =
                            nestedResultIndices[i]->asMutable<vector_size_t>();
                    rawNestedResultNulls[i] = nestedResultNulls[i]->asMutable<uint64_t>();
                }

                const auto resultArraySize = rows.end();
                BufferPtr resultArrayOffsets = allocateOffsets(resultArraySize, pool);
                auto rawResultArrayOffsets = resultArrayOffsets->asMutable<vector_size_t>();

                // Create right offsets/indexes for the individual and final result arrays.
                int elementRow = 0;
                rows.applyToSelected([&](auto row) {
                    // Get the max size for that row.
                    auto maxArraySize = rawResultArraySizes[row];
                    rawResultArrayOffsets[row] = elementRow;

                    for (int i = 0; i < numInputArrays; i++) {
                        auto offset = rawOffsets[i][indices[i][row]];
                        auto size = rawSizes[i][indices[i][row]];
                        std::iota(
                            rawNestedResultIndices[i] + elementRow,
                            rawNestedResultIndices[i] + elementRow + size,
                            offset);
                        bits::fillBits(
                            rawNestedResultNulls[i],
                            elementRow + size,
                            elementRow + maxArraySize,
                            bits::kNull);
                    }
                    elementRow += maxArraySize;
                });

                // Create result dictionary vectors.
                std::vector<VectorPtr> resultDictionaryVectors(numInputArrays);

                for (int i = 0; i < numInputArrays; i++) {
                    resultDictionaryVectors[i] = BaseVector::wrap_in_dictionary(
                        nestedResultNulls[i],
                        nestedResultIndices[i],
                        resultElementsSize,
                        baseVectors[i]->elements());
                }

                auto rowType = outputType->childAt(0);
                auto row_vector = std::make_shared<RowVector>(
                    pool,
                    rowType,
                    BufferPtr(nullptr),
                    resultElementsSize,
                    resultDictionaryVectors);

                // Now convert these to an Array
                auto array_vector = std::make_shared<ArrayVector>(
                    pool,
                    outputType,
                    BufferPtr(nullptr),
                    rows.end(),
                    resultArrayOffsets,
                    resultArraySizesBuffer,
                    std::move(row_vector));

                context.moveOrCopyResult(array_vector, rows, result);
            }

            static std::vector<std::shared_ptr<exec::FunctionSignature> > signatures() {
                static const auto kAritySize = (kMaxArity - kMinArity) + 1;
                std::vector<std::shared_ptr<exec::FunctionSignature> > signatures;
                signatures.reserve(kAritySize);

                // Build all signatures from kMinArity to kMaxArity.
                for (int i = kMinArity; i <= kMaxArity; i++) {
                    auto builder = exec::FunctionSignatureBuilder();
                    std::vector<std::string> allTypeVars;
                    allTypeVars.reserve(i);

                    for (int j = 0; j < i; j++) {
                        allTypeVars.emplace_back(fmt::format("E{:02d}", j));
                        builder.typeVariable(allTypeVars.back());
                        builder.argumentType(fmt::format("array({})", allTypeVars.back()));
                    }

                    const auto returnType = melon::join(",", allTypeVars);
                    builder.returnType(fmt::format("array(row({}))", returnType));
                    signatures.emplace_back(builder.build());
                }

                return signatures;
            }
        };
    } // namespace

    POLLUX_DECLARE_VECTOR_FUNCTION(
        udf_zip,
        ZipFunction::signatures(),
        std::make_unique<ZipFunction>());
} // namespace kumo::pollux::functions
